{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 1. Install dependent libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# ! pip install -q paddleseg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 2. Unzip datasets\n",
    "**Data organization**\n",
    "```\n",
    "dataset\n",
    "  ├ img\n",
    "  │  ├ build0.jpg\n",
    "  │  └ ....jpg\n",
    "  └ gt\n",
    "     ├ build0.png\n",
    "     └ ....png\n",
    "```\n",
    "**Label pixel**\n",
    "```\n",
    "0 background\n",
    "1 building\n",
    "```\n",
    "**Image size**\n",
    "```\n",
    "512x512\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# ! mkdir -p dataset  # create a folder for save dataset\n",
    "# ! unzip -oq data.zip -d dataset  # unzip the zip data to the dataset folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "PATH = os.getcwd()\n",
    "print(PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 3. Split datasets and create data_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "\n",
    "def create_list(data_path: str, val_num: int=2000) -> None:\n",
    "    \"\"\" create list.\n",
    "    args:\n",
    "        data_path (str): dataset folder.\n",
    "        val_num (int, optional): number of evaluation data.\n",
    "    \"\"\"\n",
    "    image_path = osp.join(data_path, \"img\")\n",
    "    data_names = os.listdir(image_path)\n",
    "    random.shuffle(data_names)  # scramble data\n",
    "    with open(os.path.join(data_path, \"train_list.txt\"), \"w\") as tf:\n",
    "        with open(os.path.join(data_path, \"val_list.txt\"), \"w\") as vf:\n",
    "            for idx, data_name in enumerate(data_names):\n",
    "                img = os.path.join(\"img\", data_name)\n",
    "                lab = os.path.join(\"gt\", data_name.replace(\"jpg\", \"png\"))\n",
    "                if idx < val_num:\n",
    "                    vf.write(img + \" \" + lab + \"\\n\")\n",
    "                else:\n",
    "                    tf.write(img + \" \" + lab + \"\\n\")\n",
    "    print(\"Data list generation completed\")\n",
    "\n",
    "\n",
    "create_list(osp.join(PATH, \"dataset\"), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 4. Create PaddlePaddle Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddleseg.transforms as T\n",
    "from paddleseg.datasets import Dataset\n",
    "\n",
    "\n",
    "# build the training set\n",
    "train_transforms = [T.RandomHorizontalFlip(),\n",
    "                    T.RandomVerticalFlip(),\n",
    "                    T.RandomRotation(),\n",
    "                    T.RandomScaleAspect(),\n",
    "                    T.RandomBlur(),\n",
    "                    T.Resize(target_size=(512, 512)),\n",
    "                    T.Normalize()]\n",
    "train_dataset = Dataset(transforms=train_transforms,\n",
    "                        dataset_root=osp.join(PATH, \"dataset\"),\n",
    "                        num_classes=2,\n",
    "                        mode=\"train\",\n",
    "                        train_path=osp.join(PATH, \"dataset/train_list.txt\"),\n",
    "                        separator=\" \")\n",
    "\n",
    "# build validation set\n",
    "val_transforms = [T.Resize(target_size=(512, 512)),\n",
    "                  T.Normalize()]\n",
    "val_dataset = Dataset(transforms=val_transforms,\n",
    "                      dataset_root=osp.join(PATH, \"dataset\"),\n",
    "                      num_classes=2,\n",
    "                      mode=\"val\",\n",
    "                      val_path=osp.join(PATH, \"dataset/val_list.txt\"),\n",
    "                      separator=\" \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 5. Select model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddleseg.models import OCRNet, HRNet_W18\n",
    "\n",
    "\n",
    "model = OCRNet(num_classes=2,\n",
    "               backbone=HRNet_W18(),\n",
    "               backbone_indices=[0],\n",
    "               pretrained=osp.join(PATH, \"weight/ocrnet_hrnet_w18_512x512_rs_building.pdparams\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 6. Set super-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from paddleseg.models.losses import MixedLoss, BCELoss, DiceLoss\n",
    "\n",
    "\n",
    "base_lr = 3e-5\n",
    "epochs = 5\n",
    "batch_size = 1\n",
    "\n",
    "iters = epochs * len(train_dataset) // batch_size\n",
    "lr = paddle.optimizer.lr.PolynomialDecay(base_lr, decay_steps=iters // epochs, power=0.9)\n",
    "optimizer = paddle.optimizer.Adam(lr, parameters=model.parameters())\n",
    "losses = {}\n",
    "losses[\"types\"] = [MixedLoss([BCELoss(), DiceLoss()], [1, 1])] * 2\n",
    "losses[\"coef\"] = [1] * 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 7. Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from paddleseg.core import train\n",
    "\n",
    "\n",
    "train(model=model,\n",
    "      train_dataset=train_dataset,\n",
    "      val_dataset=val_dataset,\n",
    "      optimizer=optimizer,\n",
    "      save_dir=osp.join(PATH, \"output\"),\n",
    "      iters=iters,\n",
    "      batch_size=batch_size,\n",
    "      save_interval=iters // 5,\n",
    "      log_iters=10,\n",
    "      num_workers=0,\n",
    "      losses=losses,\n",
    "      use_vdl=True)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "779fd89e231d504f3761036cc866499d7be8785d89f9802b1abf9e02a6b7fe30"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
